{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GPT-2 Secure inference with Puma\n",
    "\n",
    "In this lab, we showcase how to run 3PC secure inference on a pre-trained [GPT-2](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) model for text generation with Puma.\n",
    "\n",
    "First, we show how to use JAX and the Hugging Face Transformers library for text generation with the pre-trained GPT-2 model. After that, we show how to use Puma on the top of SPU for secure text generation with minor modifications to the plaintext counterpart. "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
The following codes">
    "> The following code is for demonstration only. It is **NOT for production** use because of system security concerns, so please **DO NOT** use it directly in production."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
 This tutorial may">
    "> This tutorial may need more resources than 16 CPU cores and 48 GB of memory (16c48g)."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What is Puma?\n",
    "\n",
    "Puma is a fast and accurate end-to-end 3-party secure Transformer models inference framework. \n",
    "Puma designs high quality approximations for expensive functions, such as $\\mathsf{GeLU}$ and $\\mathsf{Softmax}$, which significantly reduce the cost of secure inference while preserving the model performance. Additionally, we design secure $\\mathsf{Embedding}$ and $\\mathsf{LayerNorm}$ procedures that faithfully implement the desired functionality without undermining the Transformer architecture.\n",
    "Puma is approximately $2\\times$ faster than the state-of-the-art MPC framework MPCFormer (ICLR 2023) and has similar accuracy as plaintext models without fine-tuning (which the previous works failed to achieve).  "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Text generation using GPT-2 with JAX/FLAX\n",
    "### Install the transformers library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "!{sys.executable} -m pip install transformers[flax]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
The JAX version">
    "> The JAX version required by transformers conflicts with the version that SPU requires. For this example, however, it is fine to run SPU with the conflicting JAX version."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the pre-trained GPT-2 Model\n",
    "\n",
    "Please refer to this [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/gpt2) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
    "pretrained_model = FlaxGPT2LMHeadModel.from_pretrained(\"gpt2\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To hack GeLU function of GPT2, you need to change the `self.act` as `jax.nn.gelu` to hack `gelu`.\n",
    "For example, in `transformers/src/transformers/models/gpt2/modeling_flax_gpt2.py`, line 296:\n",
    "\n",
    "```python\n",
    "hidden_states = self.act(hidden_states)\n",
    "```\n",
    "\n",
    "is changed as\n",
    "\n",
    "```python\n",
    "hidden_states = jax.nn.gelu(hidden_states)\n",
    "```\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the text generation function\n",
    "\n",
    "\n",
    "We use a [greedy search strategy](https://huggingface.co/blog/how-to-generate) for text generation here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def text_generation(input_ids, params):\n",
    "    config = GPT2Config()\n",
    "    model = FlaxGPT2LMHeadModel(config=config)\n",
    "\n",
    "    for _ in range(10):\n",
    "        outputs = model(input_ids=input_ids, params=params)\n",
    "        next_token_logits = outputs[0][0, -1, :]\n",
    "        next_token = jnp.argmax(next_token_logits)\n",
    "        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)\n",
    "    return input_ids"
   ]
  },
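  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The greedy loop in `text_generation` can be illustrated without a real model. The sketch below is only an illustration: `toy_logits` is a hypothetical stand-in for the GPT-2 forward pass (it is not part of Transformers or SPU), and only the pattern of picking the highest-scoring token and appending it carries over.\n",
    "\n",
    "```python\n",
    "def toy_logits(token_ids, vocab_size=5):\n",
    "    # Hypothetical stand-in for a language model: the token after `last`\n",
    "    # (cyclically) gets the highest score.\n",
    "    last = token_ids[-1]\n",
    "    return [1.0 if tok == (last + 1) % vocab_size else 0.0 for tok in range(vocab_size)]\n",
    "\n",
    "\n",
    "def greedy_decode(token_ids, steps):\n",
    "    token_ids = list(token_ids)\n",
    "    for _ in range(steps):\n",
    "        logits = toy_logits(token_ids)\n",
    "        # Greedy search: always append the single highest-scoring token.\n",
    "        next_token = max(range(len(logits)), key=logits.__getitem__)\n",
    "        token_ids.append(next_token)\n",
    "    return token_ids\n",
    "\n",
    "\n",
    "generated = greedy_decode([0], steps=4)\n",
    "print(generated)\n",
    "```\n",
    "\n",
    "Because every step keeps only the best token, the result is deterministic for a given prompt, which makes it convenient for comparing plaintext and secure runs."
   ]
  },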
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run text generation on CPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-15 17:07:55.627043: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst\n",
      "2023-06-15 17:07:55.627112: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst\n",
      "2023-06-15 17:07:55.627118: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------------------------------------------\n",
      "Run on CPU:\n",
      "-----------------------------------------------------------------\n",
      "I enjoy walking with my cute dog, but I'm not sure if I'll ever\n",
      "-----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "inputs_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')\n",
    "outputs_ids = text_generation(inputs_ids, pretrained_model.params)\n",
    "\n",
    "print('-' * 65 + '\\nRun on CPU:\\n' + '-' * 65)\n",
    "print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))\n",
    "print('-' * 65)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we generate 10 tokens. Keep the generated text in mind, we are going to generate text on SPU in the next step."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Run text generation on SPU"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Import the necessary libraries and config the optimizations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import secretflow as sf\n",
    "from typing import Any, Callable, Dict, Optional, Tuple, Union\n",
    "import jax.nn as jnn\n",
    "import flax.linen as nn\n",
    "from flax.linen.linear import Array\n",
    "import jax\n",
    "import argparse\n",
    "import spu.utils.distributed as ppd\n",
    "import spu.intrinsic as intrinsic\n",
    "import spu\n",
    "from contextlib import contextmanager\n",
    "\n",
    "copts = spu.CompilerOptions()\n",
    "copts.enable_pretty_print = False\n",
    "copts.xla_pp_kind = spu.libspu.XLAPrettyPrintKind.HTML\n",
    "# enable x / broadcast(y) -> x * broadcast(1/y)\n",
    "copts.enable_optimize_denominator_with_broadcast = True\n",
    "Array = Any\n",
    "\n",
    "# In case you have a running secretflow runtime already.\n",
    "sf.shutdown()"
   ]
  },
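  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The compiler option above rewrites `x / broadcast(y)` as `x * broadcast(1/y)`. The plaintext sketch below shows the idea on a row normalization, assuming (as is typical under MPC) that a division costs noticeably more than a multiplication. The variable names are illustrative and nothing here calls SPU.\n",
    "\n",
    "```python\n",
    "rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]\n",
    "row_sums = [sum(row) for row in rows]\n",
    "\n",
    "# Direct form: one division per element.\n",
    "direct = [[v / s for v in row] for row, s in zip(rows, row_sums)]\n",
    "\n",
    "# Rewritten form: one reciprocal per row, then one multiplication per element.\n",
    "reciprocals = [1.0 / s for s in row_sums]\n",
    "rewritten = [[v * r for v in row] for row, r in zip(rows, reciprocals)]\n",
    "\n",
    "max_diff = max(\n",
    "    abs(a - b) for ra, rb in zip(direct, rewritten) for a, b in zip(ra, rb)\n",
    ")\n",
    "print(max_diff)\n",
    "```\n",
    "\n",
    "Both forms agree up to floating-point rounding, while the rewritten form needs only one reciprocal per broadcast denominator."
   ]
  },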
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define the Softmax hijack function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hack_softmax(\n",
    "    x: Array,\n",
    "    axis: Optional[Union[int, Tuple[int, ...]]] = -1,\n",
    "    where: Optional[Array] = None,\n",
    "    initial: Optional[Array] = None,\n",
    ") -> Array:\n",
    "    x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)\n",
    "    x = x - x_max\n",
    "\n",
    "    # exp on large negative is clipped to zero\n",
    "    b = x > -14\n",
    "    nexp = jnp.exp(x)\n",
    "\n",
    "    divisor = jnp.sum(nexp, axis, where=where, keepdims=True)\n",
    "\n",
    "    return b * (nexp / divisor)\n",
    "\n",
    "\n",
    "@contextmanager\n",
    "def hack_softmax_context(msg: str, enabled: bool = False):\n",
    "    if not enabled:\n",
    "        yield\n",
    "        return\n",
    "    # hijack some target functions\n",
    "    raw_softmax = jnn.softmax\n",
    "    jnn.softmax = hack_softmax\n",
    "    yield\n",
    "    # recover back\n",
    "    jnn.softmax = raw_softmax"
   ]
  },
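  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a plaintext sanity check, the clipping used by `hack_softmax` can be compared with an ordinary softmax. The sketch below is a pure-Python re-implementation on a 1-D list, written for illustration only; `reference_softmax` and `clipped_softmax` are hypothetical helper names, and the check says nothing about fixed-point behaviour on SPU.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def reference_softmax(xs):\n",
    "    m = max(xs)\n",
    "    exps = [math.exp(x - m) for x in xs]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "def clipped_softmax(xs, threshold=-14.0):\n",
    "    # Mirrors hack_softmax: shift by the max, then zero every output whose\n",
    "    # shifted input is not greater than the threshold.\n",
    "    m = max(xs)\n",
    "    shifted = [x - m for x in xs]\n",
    "    exps = [math.exp(s) for s in shifted]\n",
    "    total = sum(exps)\n",
    "    return [(e / total) if s > threshold else 0.0 for s, e in zip(shifted, exps)]\n",
    "\n",
    "\n",
    "scores = [2.0, 1.0, 0.0, -20.0]\n",
    "max_diff = max(\n",
    "    abs(a - b) for a, b in zip(reference_softmax(scores), clipped_softmax(scores))\n",
    ")\n",
    "print(max_diff)\n",
    "```\n",
    "\n",
    "Entries more than 14 below the maximum contribute almost nothing to the result, so zeroing them changes the plaintext output only marginally."
   ]
  },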
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define the GeLU hijack function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hack_gelu(\n",
    "    x: Array,\n",
    "    axis: Optional[Union[int, Tuple[int, ...]]] = -1,\n",
    "    where: Optional[Array] = None,\n",
    "    initial: Optional[Array] = None,\n",
    ") -> Array:\n",
    "    b0 = x < -4.0\n",
    "    b1 = x < -1.95\n",
    "    b2 = x > 3.0\n",
    "    b3 = b1 ^ b2 ^ True  # x in [-1.95, 3.0]\n",
    "    b4 = b0 ^ b1  # x in [-4, -1.95]\n",
    "\n",
    "    # seg1 = a[3] * x^3 + a[2] * x^2 + a[1] * x + a[0]\n",
    "    # seg2 = b[6] * x^6 + b[4] * x^4 + b[2] * x^2 + b[1] * x + b[0]\n",
    "    a_coeffs = jnp.array(\n",
    "        [\n",
    "            -0.5054031199708174,\n",
    "            -0.42226581151983866,\n",
    "            -0.11807612951181953,\n",
    "            -0.011034134030615728,\n",
    "        ]\n",
    "    )\n",
    "    b_coeffs = jnp.array(\n",
    "        [\n",
    "            0.008526321541038084,\n",
    "            0.5,\n",
    "            0.3603292692789629,\n",
    "            0.0,\n",
    "            -0.037688200365904236,\n",
    "            0.0,\n",
    "            0.0018067462606141187,\n",
    "        ]\n",
    "    )\n",
    "    x2 = jnp.square(x)\n",
    "    x3 = jnp.multiply(x, x2)\n",
    "    x4 = jnp.square(x2)\n",
    "    x6 = jnp.square(x3)\n",
    "\n",
    "    seg1 = a_coeffs[3] * x3 + a_coeffs[2] * x2 + a_coeffs[1] * x + a_coeffs[0]\n",
    "    seg2 = (\n",
    "        b_coeffs[6] * x6\n",
    "        + b_coeffs[4] * x4\n",
    "        + b_coeffs[2] * x2\n",
    "        + b_coeffs[1] * x\n",
    "        + b_coeffs[0]\n",
    "    )\n",
    "\n",
    "    ret = b2 * x + b4 * seg1 + b3 * seg2\n",
    "\n",
    "    return ret\n",
    "\n",
    "\n",
    "@contextmanager\n",
    "def hack_gelu_context(msg: str, enabled: bool = False):\n",
    "    if not enabled:\n",
    "        yield\n",
    "        return\n",
    "    # hijack some target functions\n",
    "    raw_gelu = jnn.gelu\n",
    "    jnn.gelu = hack_gelu\n",
    "    yield\n",
    "    # recover back\n",
    "    jnn.gelu = raw_gelu"
   ]
  },
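  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The piecewise polynomial in `hack_gelu` can be checked against the exact GeLU in plaintext. The sketch below copies the coefficients from the cell above and uses plain branches instead of boolean masks; `exact_gelu` and `piecewise_gelu` are hypothetical helper names, and the comparison uses the erf-based GeLU definition rather than any SPU computation.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "A = [-0.5054031199708174, -0.42226581151983866,\n",
    "     -0.11807612951181953, -0.011034134030615728]\n",
    "B = [0.008526321541038084, 0.5, 0.3603292692789629, 0.0,\n",
    "     -0.037688200365904236, 0.0, 0.0018067462606141187]\n",
    "\n",
    "\n",
    "def exact_gelu(x):\n",
    "    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))\n",
    "\n",
    "\n",
    "def piecewise_gelu(x):\n",
    "    # Same four segments as hack_gelu.\n",
    "    if x < -4.0:\n",
    "        return 0.0\n",
    "    if x < -1.95:\n",
    "        return A[3] * x**3 + A[2] * x**2 + A[1] * x + A[0]\n",
    "    if x > 3.0:\n",
    "        return x\n",
    "    return B[6] * x**6 + B[4] * x**4 + B[2] * x**2 + B[1] * x + B[0]\n",
    "\n",
    "\n",
    "grid = [i / 100.0 for i in range(-600, 601)]\n",
    "max_err = max(abs(piecewise_gelu(x) - exact_gelu(x)) for x in grid)\n",
    "print(max_err)\n",
    "```\n",
    "\n",
    "The approximation needs only comparisons, multiplications and additions, which are the operations that are cheap to evaluate on secret-shared values."
   ]
  },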
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Launch Puma on GPT2:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-15 17:08:14,157\tINFO worker.py:1538 -- Started a local Ray instance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(pid=2109508)\u001b[0m Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n",
      "\u001b[2m\u001b[36m(pid=2109408)\u001b[0m Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n",
      "\u001b[2m\u001b[36m(pid=2121303)\u001b[0m Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n",
      "\u001b[2m\u001b[36m(pid=2121304)\u001b[0m Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n",
      "\u001b[2m\u001b[36m(pid=2121301)\u001b[0m Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n",
      "\u001b[2m\u001b[36m(_run pid=2109408)\u001b[0m INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n",
      "\u001b[2m\u001b[36m(_run pid=2109408)\u001b[0m INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host\n",
      "\u001b[2m\u001b[36m(_run pid=2109408)\u001b[0m INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.\n",
      "\u001b[2m\u001b[36m(_run pid=2109408)\u001b[0m WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
      "\u001b[2m\u001b[36m(_run pid=2109508)\u001b[0m INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n",
      "\u001b[2m\u001b[36m(_run pid=2109508)\u001b[0m INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host\n",
      "\u001b[2m\u001b[36m(_run pid=2109508)\u001b[0m INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.\n",
      "\u001b[2m\u001b[36m(_run pid=2109508)\u001b[0m WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(_run pid=2109408)\u001b[0m [2023-06-15 17:08:24.221] [info] [thread_pool.cc:30] Create a fixed thread pool with size 127\n"
     ]
    }
   ],
   "source": [
    "sf.init(['alice', 'bob', 'carol'], address='local')\n",
    "\n",
    "alice, bob = sf.PYU('alice'), sf.PYU('bob')\n",
    "conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])\n",
    "conf['runtime_config']['protocol'] = 'ABY3'\n",
    "conf['runtime_config']['field'] = 'FM64'\n",
    "conf['runtime_config']['fxp_exp_mode'] = 0\n",
    "conf['runtime_config']['fxp_exp_iters'] = 5\n",
    "\n",
    "spu = sf.SPU(conf)\n",
    "\n",
    "\n",
    "def get_model_params():\n",
    "    pretrained_model = FlaxGPT2LMHeadModel.from_pretrained(\"gpt2\")\n",
    "    return pretrained_model.params\n",
    "\n",
    "\n",
    "def get_token_ids():\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
    "    return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')\n",
    "\n",
    "\n",
    "model_params = alice(get_model_params)()\n",
    "input_token_ids = bob(get_token_ids)()\n",
    "\n",
    "device = spu\n",
    "model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)\n",
    "\n",
    "with hack_softmax_context(\"hijack jax softmax\", enabled=True), hack_gelu_context(\n",
    "    \"hack jax gelu\", enabled=True\n",
    "):\n",
    "    output_token_ids = spu(text_generation, copts=copts)(\n",
    "        input_token_ids_, model_params_\n",
    "    )"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check the Puma output\n",
    "\n",
    "As you can see, it's very easy to run GPT-2 inference on Puma. Now let's reveal the generated text from Puma."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(_spu_compile pid=2109408)\u001b[0m 2023-06-15 17:09:12.722333: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst\n",
      "\u001b[2m\u001b[36m(_spu_compile pid=2109408)\u001b[0m 2023-06-15 17:09:12.722414: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst\n",
      "\u001b[2m\u001b[36m(_spu_compile pid=2109408)\u001b[0m 2023-06-15 17:09:12.722421: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(SPURuntime(device_id=None, party=bob) pid=2121303)\u001b[0m 2023-06-15 17:09:32.011 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127\n",
      "\u001b[2m\u001b[36m(SPURuntime(device_id=None, party=alice) pid=2121301)\u001b[0m 2023-06-15 17:09:32.011 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127\n",
      "\u001b[2m\u001b[36m(SPURuntime(device_id=None, party=carol) pid=2121304)\u001b[0m 2023-06-15 17:09:32.011 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127\n",
      "-----------------------------------------------------------------\n",
      "Run on SPU:\n",
      "-----------------------------------------------------------------\n",
      "I enjoy walking with my cute dog, but I'm not sure if I'll ever\n",
      "-----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "outputs_ids = sf.reveal(output_token_ids)\n",
    "print('-' * 65 + '\\nRun on SPU:\\n' + '-' * 65)\n",
    "print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))\n",
    "print('-' * 65)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the generated text from Puma is exactly same as the generated text from CPU!"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the end of the lab.\n",
    "For more benchmarks about Puma, please refer to: https://github.com/secretflow/spu/tree/main/examples/python/ml/flax_llama7b"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  },
  "vscode": {
   "interpreter": {
    "hash": "db45a4cb4cd37a8de684dfb7fcf899b68fccb8bd32d97c5ad13e5de1245c0986"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
